# Restore the saved R image this script starts from. `hospital.stays` and
# `single.sampled` are read further down but never created in the code below,
# so they are presumably supplied by this image -- TODO confirm.
load("fig7_input_image.rda")

library(tidyverse)
library(lubridate)
library(splitstackshape)
library(scales)
library(grid)
library(tidyverse)
library(glue)
library(viridis)
library(grid)
library(lemon)

# ---- Node table ----
# Builds df.for.nodes: one row per site label found in the collapsed-tree file,
# with the host id, the ward used for colouring, and a trace flag.
ct <- read_csv("posterior_collapsed_trees_sites/Collapsed_Tree_Posterior_sites_10percol_1.csv")

# Distinct site labels in sorted order (order() places any NA last).
hosts <- unique(ct$hosts)
hosts <- hosts[order(hosts)]

# Drop the label that sorts last.
# NOTE(review): presumably a placeholder entry rather than a real site -- confirm.
# head(x, -1) replaces x[1:(length(x) - 1)], which keeps the element when there
# is exactly one label and errors when there are none.
hosts <- head(hosts, -1)

# The host id is the part of the site label before the first underscore.
df.for.nodes <- tibble(
  site = hosts,
  host = vapply(strsplit(hosts, "_"), function(parts) parts[1], character(1))
)

# Every row starts with the ward "harris"; hosts whose id starts with "T" then
# take their ward from hospital.stays (NA when the host is not listed there).
df.for.nodes$ward <- "harris"
tong.rows <- which(startsWith(df.for.nodes$host, "T"))
df.for.nodes$ward[tong.rows] <- hospital.stays$ward[
  match(df.for.nodes$host[tong.rows], hospital.stays$split.host)
]

# Sort by ward, then freeze that row order into the factor levels.
df.for.nodes <- df.for.nodes[order(df.for.nodes$ward), ]

df.for.nodes$site <- factor(df.for.nodes$site, levels = unique(df.for.nodes$site))
df.for.nodes$host <- factor(df.for.nodes$host, levels = unique(df.for.nodes$host))

# single.host is the first four characters of the host id.
df.for.nodes$single.host <- substr(df.for.nodes$host, 1, 4)
df.for.nodes$single.host <- factor(
  df.for.nodes$single.host,
  levels = unique(df.for.nodes$single.host)
)

# TRUE for hosts listed in single.sampled.
df.for.nodes$trace <- df.for.nodes$host %in% single.sampled

# ---- Edge table ----
# Reads the pairwise host-relationship summary and reshapes it to one row per
# unordered pair of sites, with tree counts for each direction of ancestry.
edges.tong <- read_csv("Posterior_sites_10percol_hostRelationshipSummary.csv")

edges.tong <- edges.tong %>%
  # Discard any pair involving a host whose label starts with "S".
  filter(!str_starts(host.1, "S") & !str_starts(host.2, "S")) %>%
  mutate(
    # Put each pair in alphabetical order so that A-B and B-A collapse together.
    host.alphabetical.1 = pmin(host.1, host.2),
    host.alphabetical.2 = pmax(host.1, host.2),
    reverse = host.1 > host.2,
    # Re-express transmission relative to the alphabetically first host:
    # "anc" if that host is the source, "desc" if it is the recipient.
    # Any other value (including NA) is passed through unchanged.
    ancestry = case_when(
      ancestry == "trans" & reverse ~ "desc",
      ancestry == "trans" ~ "anc",
      ancestry == "multi_trans" & reverse ~ "multi_desc",
      ancestry == "multi_trans" ~ "multi_anc",
      TRUE ~ ancestry
    )
  ) %>%
  select(host.alphabetical.1, host.alphabetical.2, ancestry, ancestry.tree.count) %>%
  filter(ancestry != "none") %>%
  rename(host.1 = host.alphabetical.1, host.2 = host.alphabetical.2) %>%
  # One column per ancestry category; pairs lacking a category get a count of 0.
  spread(key = ancestry, value = ancestry.tree.count, fill = 0)

# spread() only creates columns for categories that occur in the data. Add any
# that are absent as all-zero counts so the arithmetic below cannot fail with
# "object not found".
for (category in setdiff(c("anc", "desc", "multi_anc", "multi_desc", "complex"),
                         colnames(edges.tong))) {
  edges.tong[[category]] <- rep(0, nrow(edges.tong))
}

edges.tong <- edges.tong %>%
  # Fold the "multi" categories into their single-direction counterparts.
  mutate(anc = anc + multi_anc, desc = desc + multi_desc) %>%
  select(host.1, host.2, anc, desc, complex) %>%
  mutate(total = anc + desc + complex) %>%
  # Keep pairs supported by at least 50 trees.
  filter(total >= 50) %>%
  # Split the "complex" count evenly between the two directions (raw counts).
  mutate(
    anc.plus = anc + 0.5 * complex,
    desc.plus = desc + 0.5 * complex
  ) %>%
  # Rescale counts to proportions.
  # NOTE(review): assumes the summary covers 100 trees -- confirm.
  mutate(
    anc = anc / 100,
    desc = desc / 100,
    complex = complex / 100,
    total = total / 100
  )

# Nodes kept for the network: those whose single.host id starts with "T".
nodes.tong <- df.for.nodes[which(startsWith(as.character(df.for.nodes$single.host), "T")), ]

library(igraph)
library(intergraph)
library(network)
library(GGally)

# ---- Graph objects, layout and labels for the full network ----
# Undirected igraph object (vertices taken from nodes.tong) and its `network`
# equivalent.
ig <- graph_from_data_frame(edges.tong, directed = FALSE, vertices = nodes.tong)

network.rep <- asNetwork(ig)

# Node coordinates, one row per site label.
arrangement <- read_csv("figure_7S1_node_arrangement.csv")

# Look up each node's position by site label (NA if the label is not listed).
nodes.tong$x <- arrangement$x[match(nodes.tong$site, arrangement$label)]
nodes.tong$y <- arrangement$y[match(nodes.tong$site, arrangement$label)]

# Fix the ward order used by the fill scale; any other ward value becomes NA.
nodes.tong$ward <- factor(nodes.tong$ward, levels = c("PICU 2", "surg", "hcw"))

# Display label: replace the first character of the site label with "C" and
# put each underscore-separated part on its own line.
nodes.tong$nice.label <- substr(nodes.tong$site, 2, nchar(as.character(nodes.tong$site)))
nodes.tong$nice.label <- paste0("C", nodes.tong$nice.label)
nodes.tong$nice.label <- gsub("_", "\n", nodes.tong$nice.label)
# Host T056 gets a fixed label.
nodes.tong$nice.label[which(nodes.tong$host == "T056")] <- "C056\nF"

# Combined ward/trace key. The levels must match paste0(ward, trace) exactly,
# so the PICU level is "PICU 2TRUE" (with the space); the previous "PICU2TRUE"
# never matched and turned those nodes into NA.
nodes.tong$wardtrace <- paste0(nodes.tong$ward, nodes.tong$trace)

nodes.tong$wardtrace <- factor(
  nodes.tong$wardtrace,
  levels = c("PICU 2FALSE", "PICU 2TRUE", "surgFALSE", "surgTRUE", "hcwFALSE", "hcwTRUE")
)

# Edge end points are the positions of the two sites they connect.
edges.tong$x.start <- arrangement$x[match(edges.tong$host.1, arrangement$label)]
edges.tong$y.start <- arrangement$y[match(edges.tong$host.1, arrangement$label)]
edges.tong$x.end <- arrangement$x[match(edges.tong$host.2, arrangement$label)]
edges.tong$y.end <- arrangement$y[match(edges.tong$host.2, arrangement$label)]
# Midpoint of each edge, used to place its label.
edges.tong$x.midpoint <- (edges.tong$x.end + edges.tong$x.start) / 2
edges.tong$y.midpoint <- (edges.tong$y.end + edges.tong$y.start) / 2

# Trims a straight segment by `shortening` at both ends.
#
# x.start, y.start / x.end, y.end: coordinates of the two end points (scalars).
# shortening: distance removed from each end, in the same units.
#
# Returns a list with elements x.start.shortened, x.end.shortened,
# y.start.shortened and y.end.shortened. Stops with an error when the segment
# is shorter than twice `shortening`.
shorten <- function(x.start, x.end, y.start, y.end, shortening){
  dx <- x.end - x.start
  dy <- y.end - y.start
  segment.length <- sqrt(dx^2 + dy^2)

  if (segment.length < 2 * shortening) {
    stop("No line remains after shortening")
  }

  # Fractions of the way along the segment at which the trimmed line begins
  # and ends.
  fraction.from <- shortening / segment.length
  fraction.to <- 1 - fraction.from

  list(
    x.start.shortened = x.start + fraction.from * dx,
    x.end.shortened = x.start + fraction.to * dx,
    y.start.shortened = y.start + fraction.from * dy,
    y.end.shortened = y.start + fraction.to * dy
  )
}

# Shifts a straight segment sideways, perpendicular to its own direction.
#
# x.start, y.start / x.end, y.end: coordinates of the two end points (scalars).
# nudge.distance: signed distance to move. Positive values move along
#   (-dy, dx), i.e. the segment direction turned 90 degrees anticlockwise;
#   negative values move the opposite way.
#
# Returns a list with elements x.start, x.end, y.start and y.end for the
# shifted segment. A zero-length segment has no direction and yields NaN.
nudge <- function(x.start, x.end, y.start, y.end, nudge.distance){
  dx <- x.end - x.start
  dy <- y.end - y.start
  segment.length <- sqrt(dx^2 + dy^2)

  # Offset = nudge.distance times the unit normal (-dy, dx) / length.
  offset.x <- nudge.distance * (-dy / segment.length)
  offset.y <- nudge.distance * (dx / segment.length)

  list(
    x.start = x.start + offset.x,
    x.end = x.end + offset.x,
    y.start = y.start + offset.y,
    y.end = y.end + offset.y
  )
}

# Finds the free end of a half-arrowhead barb attached to one end of a segment.
#
# x.start, y.start / x.end, y.end: coordinates of the segment's end points
#   (scalars).
# angle: rotation applied to the segment direction, in radians (anticlockwise
#   for positive values).
# length: length of the barb, in the same units as the coordinates.
# start: if TRUE the barb is attached at the start point and points along the
#   rotated direction; if FALSE it is attached at the end point and points
#   against the rotated direction.
#
# Returns a list with elements x and y, the coordinates of the barb's tip.
half.arrowhead <- function(x.start, x.end, y.start, y.end, angle, length, start = TRUE){
  dx <- x.end - x.start
  dy <- y.end - y.start
  line.length <- sqrt(dx^2 + dy^2)

  # Vector along the segment, rescaled to the requested barb length.
  barb.x <- dx * (length / line.length)
  barb.y <- dy * (length / line.length)

  # Standard 2D rotation of that vector by `angle`.
  rotated.x <- cos(angle) * barb.x - sin(angle) * barb.y
  rotated.y <- sin(angle) * barb.x + cos(angle) * barb.y

  if (start) {
    list(x = x.start + rotated.x, y = y.start + rotated.y)
  } else {
    list(x = x.end - rotated.x, y = y.end - rotated.y)
  }
}

# ---- Edge geometry for the full network ----
# Derives, for every edge, the coordinates of the pieces drawn by the plotting
# code: a trimmed centre line, two copies shifted to either side, and the tip
# of a half-arrowhead on each shifted copy. Each step appends columns to
# edges.tong.

# Step 1: trim 0.018 off both ends of every edge.
shortened.columns <- edges.tong %>%
  mutate(
    shortened.coords = pmap(
      list(x.start, x.end, y.start, y.end),
      function(from.x, to.x, from.y, to.y) {
        shorten(from.x, to.x, from.y, to.y, shortening = 0.018)
      }
    )
  ) %>%
  pull(shortened.coords) %>%
  transpose() %>%
  map_df(unlist)

edges.tong <- bind_cols(edges.tong, shortened.columns)

# Step 2: shift the trimmed line sideways by nudge.length in each direction.
# Columns from the positive shift get the suffix ".ln", the negative ".rn".
nudge.length <- 0.0035

left.nudge.columns <- edges.tong %>%
  mutate(
    nudged.coords = pmap(
      list(x.start.shortened, x.end.shortened, y.start.shortened, y.end.shortened),
      function(from.x, to.x, from.y, to.y) {
        nudge(from.x, to.x, from.y, to.y, nudge.distance = nudge.length)
      }
    )
  ) %>%
  pull(nudged.coords) %>%
  transpose() %>%
  map_df(unlist)

names(left.nudge.columns) <- glue("{names(left.nudge.columns)}.ln")

right.nudge.columns <- edges.tong %>%
  mutate(
    nudged.coords = pmap(
      list(x.start.shortened, x.end.shortened, y.start.shortened, y.end.shortened),
      function(from.x, to.x, from.y, to.y) {
        nudge(from.x, to.x, from.y, to.y, nudge.distance = -nudge.length)
      }
    )
  ) %>%
  pull(nudged.coords) %>%
  transpose() %>%
  map_df(unlist)

names(right.nudge.columns) <- glue("{names(right.nudge.columns)}.rn")

edges.tong <- bind_cols(edges.tong, left.nudge.columns, right.nudge.columns)

# Step 3: half-arrowhead tips. A rotation of 2*pi - pi/8 is the same as a
# rotation of -pi/8.
arrow.angle <- 2*pi - pi/8
arrow.length <- 0.01

# Barb at the start of the ".rn" line (columns x.sae / y.sae).
start.arrow.columns <- edges.tong %>%
  mutate(
    arrowhead.end.coords = pmap(
      list(x.start.rn, x.end.rn, y.start.rn, y.end.rn),
      function(from.x, to.x, from.y, to.y) {
        half.arrowhead(from.x, to.x, from.y, to.y,
                       angle = arrow.angle, length = arrow.length, start = TRUE)
      }
    )
  ) %>%
  pull(arrowhead.end.coords) %>%
  transpose() %>%
  map_df(unlist)

names(start.arrow.columns) <- glue("{names(start.arrow.columns)}.sae")

# Barb at the end of the ".ln" line (columns x.eae / y.eae).
end.arrow.columns <- edges.tong %>%
  mutate(
    arrowhead.end.coords = pmap(
      list(x.start.ln, x.end.ln, y.start.ln, y.end.ln),
      function(from.x, to.x, from.y, to.y) {
        half.arrowhead(from.x, to.x, from.y, to.y,
                       angle = arrow.angle, length = arrow.length, start = FALSE)
      }
    )
  ) %>%
  pull(arrowhead.end.coords) %>%
  transpose() %>%
  map_df(unlist)

names(end.arrow.columns) <- glue("{names(end.arrow.columns)}.eae")

edges.tong <- bind_cols(edges.tong, start.arrow.columns, end.arrow.columns)

# Edges whose two site labels share their first four characters, tagged with
# the ward of the first site.
sp.edges.tong <- edges.tong[
  which(substr(edges.tong$host.1, 1, 4) == substr(edges.tong$host.2, 1, 4)),
]
sp.edges.tong$ward <- nodes.tong$ward[match(sp.edges.tong$host.1, nodes.tong$site)]

# ---- Figure 7S1: full network ----
# Layers are drawn bottom to top in the order they are added.
out.diagram <- ggplot() + 
  # Wide black-edged grey bands joining the site pairs held in sp.edges.tong.
  geom_segment(data = sp.edges.tong, aes(x=x.start, xend = x.end, y=y.start, yend = y.end), col = "black", size=18, lineend="round") +
  geom_segment(data = sp.edges.tong, aes(x=x.start, xend = x.end, y=y.start, yend = y.end), col = "grey76", size=14, lineend="round") +
  # Default-coloured underlay (size 2) for the three strokes and two barbs of
  # each edge, so the coloured strokes (size 1.5) on top appear outlined.
  geom_segment(data = edges.tong, aes(x = x.start.shortened, xend = x.end.shortened, y= y.start.shortened, yend = y.end.shortened),  size = 2, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.start.ln, xend = x.end.ln, y= y.start.ln, yend = y.end.ln), size = 2, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.start.rn, xend = x.end.rn, y= y.start.rn, yend = y.end.rn), size = 2, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.end.ln, xend = x.eae, y= y.end.ln, yend = y.eae), size=2, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.start.rn, xend = x.sae, y= y.start.rn, yend = y.sae), size=2, lineend="round") +
  # Coloured strokes: centre line = complex, ".ln" line and its barb = anc,
  # ".rn" line and its barb = desc.
  geom_segment(data = edges.tong, aes(x = x.start.shortened, xend = x.end.shortened, y= y.start.shortened, yend = y.end.shortened, col = complex),  size = 1.5, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.start.ln, xend = x.end.ln, y= y.start.ln, yend = y.end.ln, col = anc), size = 1.5, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.start.rn, xend = x.end.rn, y= y.start.rn, yend = y.end.rn, col = desc), size = 1.5, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.end.ln, xend = x.eae, y= y.end.ln, yend = y.eae, col = anc), size=1.5, lineend="round") +
  geom_segment(data = edges.tong, aes(x = x.start.rn, xend = x.sae, y= y.start.rn, yend = y.sae, col = desc), size=1.5, lineend="round") +
  # Nodes: fill by ward, transparency by trace flag.
  geom_point(data=nodes.tong, aes(x=x, y=y, alpha = trace, fill=ward), shape=21, size=16, stroke=4) +
  # No layer maps `size` inside aes(), so this scale has nothing to act on.
  scale_size_continuous(limits = c(0,100), range = c(0.5, 2.5)) +
  scale_fill_manual(values = c("darkred", "darkgreen", "darkblue"), labels=c("Paediatric patient", "General surgery patient", "HCW"), name="Subject") +
  scale_alpha_manual(values=c(1,0.5), labels=c("Non-trace", "Trace"), name="Colonisation type") +
  scale_colour_viridis(direction = -1, limits = c(0,1), option = "magma", name = "Posterior support for\ntopological relationship") +
  # Edge label (the `total` column) at the midpoint, node label on the node.
  geom_label(data=edges.tong, aes(x=x.midpoint, y=y.midpoint, label=total), size=4, label.padding = unit(0.15, "lines")) + 
  geom_text(data=nodes.tong, aes(x=x, y=y, label=nice.label), col="white", size=3.5) + 
  theme_void()

# Show the figure without its legend (auto-printed when run at top level).
out.diagram + theme(legend.position = "none") 

# Save the legend-free figure. The plot is passed explicitly: ggsave() otherwise
# falls back to last_plot(), i.e. whichever plot was created or modified most
# recently, which is hidden global state.
ggsave(
  filename = "Figure7S1.pdf",
  plot = out.diagram + theme(legend.position = "none"),
  width = 22,
  height = 22
)

# Draw the legend on its own into a separate PDF.
pdf(file="Figure7S1Legend.pdf", width=5, height=12)

mylegend <- g_legend(out.diagram)
grid.draw(mylegend)
dev.off()

# ---- Reduced network (Figure 7): nodes with trace == FALSE only ----
reduced.nodes.tong <- nodes.tong[which(!(nodes.tong$trace)), ]

# Keep edges whose two sites both survive, and only the summary columns: the
# geometry columns are recomputed below for the new layout. Columns are picked
# by name rather than by position (previously 1:8) so that a change in column
# order cannot silently select the wrong ones.
reduced.edges.tong <- edges.tong[
  which((edges.tong$host.1 %in% reduced.nodes.tong$site) &
          (edges.tong$host.2 %in% reduced.nodes.tong$site)),
  c("host.1", "host.2", "anc", "desc", "complex", "total", "anc.plus", "desc.plus")
]

# Node coordinates for this figure, one row per site label.
arrangement <- read_csv("figure_7_node_arrangement.csv")

# Look up each node's position by site label (NA if the label is not listed).
reduced.nodes.tong$x <- arrangement$x[match(reduced.nodes.tong$site, arrangement$label)]
reduced.nodes.tong$y <- arrangement$y[match(reduced.nodes.tong$site, arrangement$label)]

# Edge end points are the positions of the two sites they connect.
reduced.edges.tong$x.start <- arrangement$x[match(reduced.edges.tong$host.1, arrangement$label)]
reduced.edges.tong$y.start <- arrangement$y[match(reduced.edges.tong$host.1, arrangement$label)]
reduced.edges.tong$x.end <- arrangement$x[match(reduced.edges.tong$host.2, arrangement$label)]
reduced.edges.tong$y.end <- arrangement$y[match(reduced.edges.tong$host.2, arrangement$label)]
# Midpoint of each edge, used to place its label.
reduced.edges.tong$x.midpoint <- (reduced.edges.tong$x.end + reduced.edges.tong$x.start) / 2
reduced.edges.tong$y.midpoint <- (reduced.edges.tong$y.end + reduced.edges.tong$y.start) / 2

# ---- Edge geometry for the reduced network ----
# Derives, for every edge, the coordinates of the pieces drawn by the plotting
# code: a trimmed centre line, two copies shifted to either side, and the tip
# of a half-arrowhead on each shifted copy. Each step appends columns to
# reduced.edges.tong.

# Step 1: trim 0.022 off both ends of every edge.
shortened.columns <- reduced.edges.tong %>%
  mutate(
    shortened.coords = pmap(
      list(x.start, x.end, y.start, y.end),
      function(from.x, to.x, from.y, to.y) {
        shorten(from.x, to.x, from.y, to.y, shortening = 0.022)
      }
    )
  ) %>%
  pull(shortened.coords) %>%
  transpose() %>%
  map_df(unlist)

reduced.edges.tong <- bind_cols(reduced.edges.tong, shortened.columns)

# Step 2: shift the trimmed line sideways by nudge.length in each direction.
# Columns from the positive shift get the suffix ".ln", the negative ".rn".
nudge.length <- 0.0035

left.nudge.columns <- reduced.edges.tong %>%
  mutate(
    nudged.coords = pmap(
      list(x.start.shortened, x.end.shortened, y.start.shortened, y.end.shortened),
      function(from.x, to.x, from.y, to.y) {
        nudge(from.x, to.x, from.y, to.y, nudge.distance = nudge.length)
      }
    )
  ) %>%
  pull(nudged.coords) %>%
  transpose() %>%
  map_df(unlist)

colnames(left.nudge.columns) <- glue("{colnames(left.nudge.columns)}.ln")

right.nudge.columns <- reduced.edges.tong %>%
  mutate(
    nudged.coords = pmap(
      list(x.start.shortened, x.end.shortened, y.start.shortened, y.end.shortened),
      function(from.x, to.x, from.y, to.y) {
        nudge(from.x, to.x, from.y, to.y, nudge.distance = -nudge.length)
      }
    )
  ) %>%
  pull(nudged.coords) %>%
  transpose() %>%
  map_df(unlist)

colnames(right.nudge.columns) <- glue("{colnames(right.nudge.columns)}.rn")

reduced.edges.tong <- bind_cols(reduced.edges.tong, left.nudge.columns, right.nudge.columns)

# Step 3: half-arrowhead tips. A rotation of 2*pi - pi/8 is the same as a
# rotation of -pi/8.
arrow.angle <- 2*pi - pi/8
arrow.length <- 0.01

# Barb at the start of the ".rn" line (columns x.sae / y.sae).
start.arrow.columns <- reduced.edges.tong %>%
  mutate(
    arrowhead.end.coords = pmap(
      list(x.start.rn, x.end.rn, y.start.rn, y.end.rn),
      function(from.x, to.x, from.y, to.y) {
        half.arrowhead(from.x, to.x, from.y, to.y,
                       angle = arrow.angle, length = arrow.length, start = TRUE)
      }
    )
  ) %>%
  pull(arrowhead.end.coords) %>%
  transpose() %>%
  map_df(unlist)

colnames(start.arrow.columns) <- glue("{colnames(start.arrow.columns)}.sae")

# Barb at the end of the ".ln" line (columns x.eae / y.eae).
end.arrow.columns <- reduced.edges.tong %>%
  mutate(
    arrowhead.end.coords = pmap(
      list(x.start.ln, x.end.ln, y.start.ln, y.end.ln),
      function(from.x, to.x, from.y, to.y) {
        half.arrowhead(from.x, to.x, from.y, to.y,
                       angle = arrow.angle, length = arrow.length, start = FALSE)
      }
    )
  ) %>%
  pull(arrowhead.end.coords) %>%
  transpose() %>%
  map_df(unlist)

colnames(end.arrow.columns) <- glue("{colnames(end.arrow.columns)}.eae")

reduced.edges.tong <- bind_cols(reduced.edges.tong, start.arrow.columns, end.arrow.columns)


# Edges whose two site labels share their first four characters, tagged with
# the ward of the first site. The lookup must use this table's own host.1
# column; it previously used sp.edges.tong (the full-network table), whose rows
# do not correspond to the rows here.
sp.reduced.edges.tong <- reduced.edges.tong[
  which(substr(reduced.edges.tong$host.1, 1, 4) == substr(reduced.edges.tong$host.2, 1, 4)),
]
sp.reduced.edges.tong$ward <- reduced.nodes.tong$ward[
  match(sp.reduced.edges.tong$host.1, reduced.nodes.tong$site)
]

# ---- Figure 7: reduced network ----
# Layers are drawn bottom to top in the order they are added.
out.diagram <- ggplot() + 
  # Wide black-edged grey bands joining the site pairs held in
  # sp.reduced.edges.tong.
  geom_segment(data = sp.reduced.edges.tong, aes(x=x.start, xend = x.end, y=y.start, yend = y.end), col = "black", size=18, lineend="round") +
  geom_segment(data = sp.reduced.edges.tong, aes(x=x.start, xend = x.end, y=y.start, yend = y.end), col = "grey76", size=14, lineend="round") +
  # Default-coloured underlay (size 2) for the three strokes and two barbs of
  # each edge, so the coloured strokes (size 1.5) on top appear outlined.
  geom_segment(data = reduced.edges.tong, aes(x = x.start.shortened, xend = x.end.shortened, y= y.start.shortened, yend = y.end.shortened),  size = 2, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.start.ln, xend = x.end.ln, y= y.start.ln, yend = y.end.ln), size = 2, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.start.rn, xend = x.end.rn, y= y.start.rn, yend = y.end.rn), size = 2, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.end.ln, xend = x.eae, y= y.end.ln, yend = y.eae), size=2, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.start.rn, xend = x.sae, y= y.start.rn, yend = y.sae), size=2, lineend="round") +
  # Coloured strokes: centre line = complex, ".ln" line and its barb = anc,
  # ".rn" line and its barb = desc.
  geom_segment(data = reduced.edges.tong, aes(x = x.start.shortened, xend = x.end.shortened, y= y.start.shortened, yend = y.end.shortened, col = complex),  size = 1.5, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.start.ln, xend = x.end.ln, y= y.start.ln, yend = y.end.ln, col = anc), size = 1.5, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.start.rn, xend = x.end.rn, y= y.start.rn, yend = y.end.rn, col = desc), size = 1.5, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.end.ln, xend = x.eae, y= y.end.ln, yend = y.eae, col = anc), size=1.5, lineend="round") +
  geom_segment(data = reduced.edges.tong, aes(x = x.start.rn, xend = x.sae, y= y.start.rn, yend = y.sae, col = desc), size=1.5, lineend="round") +
  # Nodes, filled by ward.
  geom_point(data=reduced.nodes.tong, aes(x=x, y=y, fill=ward), shape=21, size=16, stroke=4) +
  # No layer maps `size` inside aes(), so this scale has nothing to act on.
  scale_size_continuous(limits = c(0,100), range = c(0.5, 2.5)) +
  scale_fill_manual(values = c("darkred", "darkgreen", "darkblue"), labels=c("Paediatric patient", "General surgery patient", "HCW"), name="Subject") +
  scale_colour_viridis(direction = -1, limits = c(0,1), option = "magma", name = "Posterior support for\ntopological relationship") +
  # Edge label (the `total` column) at the midpoint, node label on the node.
  geom_label(data=reduced.edges.tong, aes(x=x.midpoint, y=y.midpoint, label=total), size=4, label.padding = unit(0.15, "lines")) + 
  geom_text(data=reduced.nodes.tong, aes(x=x, y=y, label=nice.label), col="white", size=3.5) + 
  theme_void()

# Show the figure without its legend (auto-printed when run at top level).
out.diagram + theme(legend.position = "none") 

# Save the legend-free figure. The plot is passed explicitly: ggsave() otherwise
# falls back to last_plot(), i.e. whichever plot was created or modified most
# recently, which is hidden global state.
ggsave(
  filename = "Figure7.pdf",
  plot = out.diagram + theme(legend.position = "none"),
  width = 18,
  height = 18
)

# Draw the legend on its own into a separate PDF.
pdf(file="Figure7Legend.pdf", width=5, height=12)
mylegend <- g_legend(out.diagram)
grid.draw(mylegend)
dev.off()